iT邦幫忙

2022 iThome 鐵人賽

0
AI & Data

JAX 好好玩系列 第 34

JAX 好好玩 (34) : 類別與 jit (2) : 註冊類別為 pytree

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載 )

另外一個讓 class 型別和 jit (以及 JAX 其他 API) 相容的方法,是註冊此一 class 為 pytree 容器 (container)。JAX 既定的 pytree 容器有 list, tuple 和 dict,使用者自訂的 class 會被視為葉節點 (leaf node) ,要讓一個自訂 class 成為一個 pytree 容器,必須要:

  • 在 class 內定義並實作兩個「方法 method」。
    • 方法一:攤平 class 的方法。
    • 方法二:重組 class 的方法。
  • 註冊此 class 為 pytree。

攤平及重組

由 pytree 對於 list, tuple 及 dict 攤平和重組的操作,我們可以學到如何在 class 中實作這兩個方法。pytree 提供的 API 分別是 tree_flatten() 和 tree_unflatten(),tree_flatten() 傳入一個 pytree 變數,傳回兩個值,第一個是攤平的葉節點 list,第二個是樹結構定義 (PyTreeDef) ;tree_unflatten() 傳入樹結構定義和攤平的葉節點 list,傳回一個 pytree。從以下的程式段可以知道它們的用法:

# 宣告一個 pytree 變數
tree_01 = [1.,(2., 3.),{'age':55, 'name':'edward'}]
# 用 tree_flatten() 來取出葉節點及 pytree 的結構
flat_leaves_01, flat_struct_01 = jtree.tree_flatten(tree_01)
 
print(type(flat_leaves_01), type(flat_struct_01))
print("=================================================")
print(flat_leaves_01)
print("=================================================")
print(flat_struct_01)

output :
https://ithelp.ithome.com.tw/upload/images/20221012/20129616z9cMVe2mak.png

# 用 tree_unflatten() 來重組 pytree
tree_unflat_01 = jtree.tree_unflatten(flat_struct_01, flat_leaves_01)
print(tree_unflat_01)

output :
[1.0, (2.0, 3.0), {'age': 55, 'name': 'edward'}]

自訂 class

要讓 class 成為一個 pytree 容器,首先得在 class 內實作符合這個 class 的 tree_flatten() 和 tree_unflatten() 方法,然後,再呼叫 register_pytree_node() 來註冊這個類別及其特有的方法就可以了:

# 定義一個 user-defined class, 並定義其 _tree_flatten, _tree_unflatten()
# ====================================================================================
 
class MyClass03():
    def __init__(self, x=1.0, y=1.0):
        self.x = x
        self.y = y
 
    def _tree_flatten(self):
        children = (self.x,self.y)  # arrays / dynamic values
        aux_data = None  # aux_data 要傳回重組此 class 的資訊,以這個例子來說
                         # 傳回 children 就夠了, 因此, 設為 None 即可
        return (children, aux_data)
 
    @classmethod # 注意! 這個修飾是必要的
    def _tree_unflatten(cls, aux_data, children): 
        return cls(*children) # 呼叫 class 的 __init__ 來建構新的 class 案例. 此例中
                              # 不需要 aux_data.
 
 
# 不必再宣告為 partial
@jax.jit
def my_func_03(cls: MyClass03, addition):
    return cls.x + cls.y + addition
# register MyClass03
 
jtree.register_pytree_node(MyClass03,
                           MyClass03._tree_flatten,
                           MyClass03._tree_unflatten)
my_class03 = MyClass03(1.0, 2.0)
 
print(f'Before Modification: {my_func_03(my_class03, 3.0)}')
print(f'                     {hash(my_class03)}')
 
# 修改 my_class 的內容
my_class03.x = 3.0
my_class03.y = 4.0
print(f'After Modification: {my_func_03(my_class03, 3.0)}')
print(f'                    {hash(my_class03)}')

output :
https://ithelp.ithome.com.tw/upload/images/20221012/20129616jdCGh2c0BP.png

在程式中有相關的說明,但另外要注意的,my_class03 在修改前後其 hash() 值保持不變,因為我們並沒有實作 hash,但此並不會影響 jax tracing,只要將 MyClass03 註冊為 pytree 即可。


上一篇
JAX 好好玩 (33) : 類別與 jit (1) : 重新定義 hash
下一篇
JAX 好好玩 (35) : Flax (1) : 準備學習 Flax
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言